{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# Sample pipeline for text feature extraction and evaluation"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Adapted from http://scikit-learn.org/stable/auto_examples/model_selection/grid_search_text_feature_extraction.html"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The dataset used in this example is the 20 newsgroups dataset, which will be automatically downloaded, then cached and reused for the document classification example.\n",
    "\n",
    "You can adjust the number of categories by giving their names to the dataset loader, or set `categories` to `nothing` to get all 20 of them."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "collapsed": true
   },
   "outputs": [],
   "source": [
    "# Python Version Authors: Olivier Grisel <olivier.grisel@ensta.org>\n",
    "#                         Peter Prettenhofer <peter.prettenhofer@gmail.com>\n",
    "#                         Mathieu Blondel <mathieu@mblondel.org>\n",
    "# License: BSD 3 clause\n",
    "\n",
    "using ScikitLearn\n",
    "using ScikitLearn.Pipelines: Pipeline\n",
    "using ScikitLearn.GridSearch: GridSearchCV\n",
    "\n",
    "@sk_import datasets: fetch_20newsgroups\n",
    "@sk_import feature_extraction.text: (CountVectorizer, TfidfTransformer)\n",
    "@sk_import linear_model: SGDClassifier"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "collapsed": false
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Loading 20 newsgroups dataset for categories:ASCIIString[\"alt.atheism\",\"talk.religion.misc\"]857 documents\n",
      "2 categories\n",
      "\n",
      "Performing grid search...\n",
      "pipeline:Any[\"vect\",\"tfidf\",\"clf\"]\n",
      "parameters:\n",
      "Dict{ASCIIString,Any}(\"vect__max_df\"=>(0.5,0.75,1.0),\"clf__penalty\"=>(\"l2\",\"elasticnet\"),\"clf__alpha\"=>(1.0e-5,1.0e-6),\"vect__ngram_range\"=>((1,1),(1,2)))\n",
      "Fitting 3 folds for each of 24 candidates, totalling 72 fits\n",
      "done in 7.997s\n",
      "Best score: 0.860\n",
      "Best parameters set:\n",
      "\tvect__max_df: 1.0\n",
      "\tclf__penalty: l2\n",
      "\tclf__alpha: 1.0e-5\n",
      "\tvect__ngram_range: (1,2)\n"
     ]
    }
   ],
   "source": [
    "###############################################################################\n",
    "# Load some categories from the training set\n",
    "categories = [\n",
    "    \"alt.atheism\",\n",
    "    \"talk.religion.misc\",\n",
    "]\n",
    "# Uncomment the following to do the analysis on all the categories\n",
    "# (Julia's `nothing` reaches the Python loader as `None`)\n",
    "#categories = nothing\n",
    "\n",
    "# println (not print) so that each message ends up on its own output line\n",
    "println(\"Loading 20 newsgroups dataset for categories:\")\n",
    "println(categories)\n",
    "\n",
    "data = fetch_20newsgroups(subset=\"train\", categories=categories)\n",
    "filenames = data[\"filenames\"]\n",
    "target_names = data[\"target_names\"]\n",
    "println(\"$(length(filenames)) documents\")\n",
    "println(\"$(length(target_names)) categories\")\n",
    "println()\n",
    "\n",
    "###############################################################################\n",
    "# define a pipeline combining a text feature extractor with a simple\n",
    "# classifier\n",
    "pipeline = Pipeline([\n",
    "    (\"vect\", CountVectorizer()),\n",
    "    (\"tfidf\", TfidfTransformer()),\n",
    "    (\"clf\", SGDClassifier()),\n",
    "])\n",
    "\n",
    "# uncommenting more parameters will give better exploring power but will\n",
    "# increase processing time in a combinatorial way\n",
    "# (keys are \"<step name>__<parameter name>\"; values are the candidates to try)\n",
    "parameters = Dict(\n",
    "    \"vect__max_df\"=> (0.5, 0.75, 1.0),\n",
    "    #\"vect__max_features\"=> (nothing, 5000, 10000, 50000),\n",
    "    \"vect__ngram_range\"=> ((1, 1), (1, 2)),  # unigrams or bigrams\n",
    "    #\"tfidf__use_idf\"=> (true, false),\n",
    "    #\"tfidf__norm\"=> (\"l1\", \"l2\"),\n",
    "    \"clf__alpha\"=> (0.00001, 0.000001),\n",
    "    \"clf__penalty\"=> (\"l2\", \"elasticnet\"),\n",
    "    #\"clf__n_iter\"=> (10, 50, 80),\n",
    ")\n",
    "\n",
    "# find the best parameters for both the feature extraction and the\n",
    "# classifier\n",
    "grid_search = GridSearchCV(pipeline, parameters, n_jobs=1, verbose=1)\n",
    "\n",
    "println(\"Performing grid search...\")\n",
    "println(\"pipeline:\", [name for (name, _) in pipeline.steps])\n",
    "println(\"parameters:\")\n",
    "println(parameters)\n",
    "\n",
    "# the search is fit on the first N_samples documents only\n",
    "N_samples = 100   # number of posts to train on\n",
    "t0 = time()\n",
    "fit!(grid_search, data[\"data\"][1:N_samples], data[\"target\"][1:N_samples])\n",
    "@printf(\"done in %0.3fs\\n\", time() - t0)\n",
    "\n",
    "@printf(\"Best score: %0.3f\\n\", grid_search.best_score_)\n",
    "println(\"Best parameters set:\")\n",
    "best_parameters = get_params(grid_search.best_estimator_)\n",
    "# sort the names so the report order is deterministic (Dict iteration order is not)\n",
    "for param_name in sort(collect(keys(parameters)))\n",
    "    println(\"\\t$param_name: $(best_parameters[param_name])\")\n",
    "end"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Julia 0.4.5",
   "language": "julia",
   "name": "julia-0.4"
  },
  "language_info": {
   "file_extension": ".jl",
   "mimetype": "application/julia",
   "name": "julia",
   "version": "0.4.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
